import torch
import torch.nn as nn
import codecs
import jieba
from torchvision import transforms
from transformers import AutoTokenizer

from django.shortcuts import render
from django.http import HttpResponse

from jiayan import load_lm
from jiayan import CharHMMTokenizer
from translation.subword_nmt.apply_bpe import BPE
from translation.clf.classifier import Classifier
from fairseq.models.transformer import TransformerModel


# Torch device for inference ("cuda" when a GPU is visible, else "cpu").
device="cuda" if torch.cuda.is_available() else "cpu"
# BERT tokenizer feeding the classical/modern text classifier below.
tokenizer_clf=AutoTokenizer.from_pretrained("bert-base-chinese")
# jiayan character-level language model; drives the classical-Chinese
# HMM tokenizer used in translate().
lm = load_lm('./translation/jiayan_models/jiayan.klm')

# Fairseq transformer checkpoints:
#   AMmodel: classical ("an") -> modern ("zh")
#   MAmodel: modern -> classical
AMmodel=TransformerModel.from_pretrained("./translation",checkpoint_file="./model/an-zh.pt",data_name_or_path="./data-bin")
MAmodel=TransformerModel.from_pretrained("./translation",checkpoint_file="./model/zh-an.pt",data_name_or_path="./data-bin_zh-an")
# Binary classifier (classical vs. modern) on top of a frozen BERT.
model=Classifier(freeze_bert=True)
# NOTE(review): map_location='cuda:0' will raise on a CPU-only host even
# though `device` above falls back to "cpu" — confirm deployment target.
model.load_state_dict(torch.load('./translation/model/clf_model.pth',map_location='cuda:0'),strict=False)

# Display labels for the language selector: 0 auto-detect, 1 classical, 2 modern.
language_dict={'0':'自动检测','1':'文言文','2':'现代文'}

def classifier(model, txt, tokenizer_clf):
    """Classify *txt* as classical (0) or modern (1) Chinese.

    Args:
        model: callable mapping a token-id tensor to a (1, 2) logit
            tensor, column 0 scoring "classical" and column 1 "modern".
        txt: raw input string to classify.
        tokenizer_clf: HuggingFace-style tokenizer; called with
            ``return_tensors="pt"`` and expected to return a mapping
            containing an ``"input_ids"`` tensor.

    Returns:
        0 if the classical logit is strictly larger, else 1 (ties -> 1).
    """
    token = tokenizer_clf(txt, return_tensors="pt")["input_ids"]
    # Pure inference: disable autograd so no gradient graph is built.
    with torch.no_grad():
        output = model(token)
    return 0 if output[0, 0] > output[0, 1] else 1

def translate(src_list, flag):
    """Translate sentence fragments and concatenate the results.

    Args:
        src_list: sentences already split on the full-width period "。";
            each fragment gets "。" re-appended before translation.
        flag: falsy -> classical→modern (jiayan HMM tokenizer + AMmodel);
            truthy -> modern→classical (jieba + MAmodel).

    Returns:
        All translated segments joined into one string.
    """
    bpe_path = ("./translation/file/bpecode.zh" if flag
                else "./translation/file/bpecode.an")
    outputs = []
    # `with` guarantees the codes file is closed even if a model call
    # raises (the original leaked the handle on error).
    with codecs.open(bpe_path, encoding='utf-8') as codes:
        # Build the BPE table ONCE. The original re-ran BPE(codes) every
        # iteration on the same handle, so from the second sentence on the
        # file was at EOF and the merge table came back empty.
        bpe = BPE(codes)
        hmm = CharHMMTokenizer(lm) if not flag else None  # loop-invariant

        for sent in src_list:
            sent = sent + "。"
            if not flag:
                tokens = " ".join(list(hmm.tokenize(sent)))
            else:
                tokens = " ".join(jieba.cut(sent))

            segmented = bpe.process_line(tokens)
            translated = (AMmodel.translate(segmented) if not flag
                          else MAmodel.translate(segmented))

            # Undo BPE: drop "@@" continuation markers and re-join chars.
            outputs.append("".join(piece.strip("@")
                                   for piece in translated.split(" ")))

    return "".join(outputs)

def index(request):
    """Main translation view: render the form and handle submissions.

    On POST, reads ``input`` (text), ``source_language`` and
    ``target_language`` ('0' auto-detect, '1' classical, '2' modern),
    runs the appropriate translation direction, and re-renders the page
    with the result. On GET, renders the empty form.
    """
    input_text = '请输入您要翻译的内容'
    result = '翻译结果'
    content = {}
    if request.method == 'POST':
        # Default to '' so a malformed POST cannot yield None.strip().
        input_text = request.POST.get('input', '')
        source_id = request.POST.get('source_language')
        target_id = request.POST.get('target_language')

        # Strip outer full-width periods first so splitting does not
        # produce empty leading/trailing segments.
        src_list = input_text.strip('。').split("。")

        if source_id == '0':
            # Auto-detect: classifier returns 0 = classical, 1 = modern.
            flag = classifier(model, input_text, tokenizer_clf)
            result = translate(src_list, flag)
            target_id = '2' if not flag else '1'
        elif source_id == '1':
            # Classical source: identity if target is also classical.
            result = input_text if target_id == '1' else translate(src_list, 0)
        elif source_id == '2':
            # Modern source: identity if target is also modern.
            result = input_text if target_id == '2' else translate(src_list, 1)

        content = {"textin": input_text, "textout": result,
                   "source_language": source_id, "target_language": target_id}
    return render(request, 'translation.html', content)

def about_us(request):
    """Render the static "about us" page."""
    return render(request, 'about_us.html')

def reset(request):
    # Renders a fresh, empty translation page.
    # NOTE(review): this renders 'translate.html' while index() renders
    # 'translation.html' — confirm which template actually exists; this
    # may be a typo.
    return render(request,'translate.html')